"""Benchmark the latency of processing a single batch of requests."""
import argparse
import dataclasses
import json
import time
from pathlib import Path
from typing import List, Optional
import math
import os
os.environ['CN_NOTIFIER_POOL_MAX'] = "1000"

import numpy as np
import torch
from tqdm import tqdm
from common import init_logger
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.utils import FlexibleArgumentParser
from vllm_mlu._mlu_utils import USE_PAGED

logger = init_logger(__name__)

def main(args: argparse.Namespace) -> None:
    """Run the single-batch latency benchmark described by ``args``.

    Builds an ``LLM`` engine from the CLI engine arguments, validates that
    the requested batch/sequence sizes fit the engine's token and KV-cache
    budgets, then times ``llm.generate`` over random dummy prompts
    (optionally under the PyTorch profiler) and reports latency statistics.
    """
    print(args)

    engine_args = EngineArgs.from_cli_args(args)

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    # Merge the dataclass fields with any extra attributes that were set on
    # ``engine_args`` but are not declared dataclass fields, so no
    # CLI-derived configuration is dropped when forwarding to ``LLM``.
    engine_args_dict_org = dataclasses.asdict(engine_args)
    engine_args_dict = {
        **engine_args_dict_org,
        **{
            k: v
            for k, v in engine_args.__dict__.items() if k not in engine_args_dict_org
        }
    }

    # NOTE(review): the mlugraph capture keywords below look like vendor (MLU)
    # extensions rather than upstream vLLM arguments — confirm against the
    # local fork's LLM signature.
    llm = LLM(**engine_args_dict,
              enable_context_mlugraph=True,
              context_batch_size_to_capture=args.batch_size,
              context_seq_len_to_capture=args.input_len)

    # Engine capacity figures used by the feasibility checks below.
    num_gpu_block          = llm.llm_engine.cache_config.num_gpu_blocks
    block_size             = llm.llm_engine.cache_config.block_size
    max_num_batched_tokens = llm.llm_engine.scheduler_config.max_num_batched_tokens
    # Total prompt tokens scheduled in one batch.
    batched_input_tokens   = args.input_len * args.batch_size
    # Tokens the KV cache must hold: per-sequence length rounded up to a
    # whole number of cache blocks, multiplied by the batch size.
    batched_tokens_align   = math.ceil((args.input_len + args.output_len) / \
                             block_size) * block_size * args.batch_size
    # Without chunked prefill the whole prompt batch must fit in a single
    # scheduler step and the KV cache must hold every sequence to completion;
    # bail out early with a hint instead of failing mid-benchmark.
    if not args.enable_chunked_prefill :
        if max_num_batched_tokens < batched_input_tokens :
            logger.error(f"The batch({args.batch_size}) * input length({args.input_len}) ="
                f" ({batched_input_tokens}) is larger than "
                f" max_num_batched_tokens({max_num_batched_tokens})")
            logger.info(f"Try --max-num-batched-tokens ({batched_input_tokens})")
            return
        elif num_gpu_block * block_size < batched_tokens_align :
            logger.error(f"Ceil of batch({args.batch_size}) * (input length"
                f" ({args.input_len}) + output length({args.output_len})) ="
                f" ({batched_tokens_align}) is larger than"
                f" mlu blocks({num_gpu_block}) * block_size({block_size}) ="
                f" ({num_gpu_block * block_size}) can hold max tokens.")
            if not USE_PAGED :
                logger.info("Try reduce block_size to make mlu blocks greater than batch,"
                    " or try increase -tp to get more mlu blocks.")
            else :
                logger.info("Try increase -tp to get more mlu blocks.")
            return
    # Generate a warning if the sum of the input length and output length
    # is less than the maximum model length, as only the first
    # `max_model_len` will be processed.
    max_length = args.input_len + args.output_len
    max_model_len = llm.llm_engine.model_config.max_model_len
    if max_length > max_model_len:
        logger.warning(
            f"The sum of input length({args.input_len}) and output"
            f" length({args.output_len}) is larger than max model"
            f" length({max_model_len})")

    # ignore_eos=True forces exactly `output_len` tokens per sequence so
    # every iteration measures the same amount of work.
    sampling_params = SamplingParams(
        n=args.n,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=args.output_len,
    )
    print(sampling_params)
    # Random token ids in [0, 10000); unseeded, so prompts differ between runs.
    dummy_prompt_token_ids = np.random.randint(10000,
                                               size=(args.batch_size,
                                                     args.input_len))
    dummy_prompts: List[PromptType] = [{
        "prompt_token_ids": batch
    } for batch in dummy_prompt_token_ids.tolist()]

    def run_to_completion(profile_dir: Optional[str] = None) -> Optional[float]:
        """Generate one full batch; return wall-clock latency in seconds.

        When `profile_dir` is given (may be a str or a Path — str() is applied
        before handing it to the trace handler), the generation runs under the
        PyTorch profiler, the trace is written there, and None is returned.
        NOTE(review): ProfilerActivity.CUDA is requested although this script
        targets MLU hardware — confirm the fork maps it to the MLU backend.
        """
        if profile_dir:
            with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
                        str(profile_dir))) as p:
                llm.generate(dummy_prompts,
                             sampling_params=sampling_params,
                             use_tqdm=False)
            print(p.key_averages())
        else:
            start_time = time.perf_counter()
            llm.generate(dummy_prompts,
                         sampling_params=sampling_params,
                         use_tqdm=False)
            end_time = time.perf_counter()
            latency = end_time - start_time
            return latency

    # Warm-up iterations are excluded from the reported statistics.
    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(profile_dir=None)

    # Profiling mode: one profiled run, then exit without benchmarking.
    if args.profile:
        profile_dir = args.profile_result_dir
        if not profile_dir:
            profile_dir = Path(
                "."
            ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=profile_dir)
        return

    # Benchmark.
    latencies = []
    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
        latencies.append(run_to_completion(profile_dir=None))
        # NOTE(review): get_metrics/dump_info are not upstream vLLM LLM
        # methods; assumed to be fork extensions — verify signature.
        if args.show_per_iter:
            llm.get_metrics(args.num_iters_warmup,
                            args.only_average,
                            args.input_len,
                            args.output_len,
                            args.tensor_parallel_size,
                            args.quantization,
                            llm.dump_info,
                            show_per_iter=args.show_per_iter)
    latencies = np.array(latencies)
    percentages = [10, 25, 50, 75, 90, 99]
    percentiles = np.percentile(latencies, percentages)
    print(f'Avg latency: {np.mean(latencies)} seconds')
    for percentage, percentile in zip(percentages, percentiles):
        print(f'{percentage}% percentile latency: {percentile} seconds')

    # Output JSON results if specified
    # (np.float64 subclasses float, so json.dump accepts it directly).
    if args.output_json:
        results = {
            "avg_latency": np.mean(latencies),
            "latencies": latencies.tolist(),
            "percentiles": dict(zip(percentages, percentiles.tolist())),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)

    # Final (aggregate) metrics dump after the benchmark loop.
    llm.get_metrics(args.num_iters_warmup,
                    args.only_average,
                    args.input_len,
                    args.output_len,
                    args.tensor_parallel_size,
                    args.quantization,
                    llm.dump_info)


if __name__ == '__main__':
    # Benchmark-specific flags are registered first; the engine's own CLI
    # flags are appended afterwards via EngineArgs.add_cli_args, and the
    # parsed namespace is handed straight to main().
    cli = FlexibleArgumentParser(
        description=('Benchmark the latency of processing a single batch of '
                     'requests till completion.'))

    cli.add_argument('--input-len', type=int, default=32)
    cli.add_argument('--output-len', type=int, default=128)
    cli.add_argument('--batch-size', type=int, default=8)
    cli.add_argument(
        '--n', type=int, default=1,
        help='Number of generated sequences per prompt.')
    cli.add_argument('--use-beam-search', action='store_true')
    cli.add_argument(
        '--num-iters-warmup', type=int, default=10,
        help='Number of iterations to run for warmup.')
    cli.add_argument(
        '--num-iters', type=int, default=30,
        help='Number of iterations to run.')
    cli.add_argument(
        '--profile', action='store_true',
        help='profile the generation process of a single batch')
    cli.add_argument(
        '--profile-result-dir', type=str, default=None,
        help=('path to save the pytorch profiler output. Can be visualized '
              'with ui.perfetto.dev or Tensorboard.'))
    cli.add_argument(
        '--output-json', type=str, default=None,
        help='Path to save the latency results in JSON format.')
    cli.add_argument(
        '--only-average', action='store_true', default=False,
        help='Show all iteration metrics or average metrics.')
    cli.add_argument(
        '--show-per-iter', action='store_true',
        help='If true, show metrics data per iteration.')

    cli = EngineArgs.add_cli_args(cli)
    main(cli.parse_args())
